import math
from html import escape as html_escape

from flask import Flask, request
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from jinja2.utils import url_quote

from .tts import text_to_speech
from .format import pcm2wav, pcm2mp3, pcm2ogg

# WSGI application object; the routes and the rate limiter below attach to it.
app = Flask(__name__)

def cost_function():
    """Return the rate-limit cost of the current request.

    Passed to the limiter as ``default_limits_cost``, so the returned integer
    is the number of "requests" this call consumes from the caller's quota.

    Every request costs at least 1.  Requests to the ``tts`` endpoint pay an
    extra surcharge based on the UTF-8 byte length of the ``text`` argument:

    * ``quality=low``    -> ``round(bytes / 100)``
    * ``quality=medium`` -> ``round(bytes / 40)``
    * any other value    -> ``round(bytes / 30)`` (the ``high`` rate)

    The surcharge is then scaled by 3/2 when ``speed`` is below 0.75 and by
    4/3 when ``format`` is ``mp3`` or ``ogg`` (integer arithmetic, rounded
    down).  An unparsable ``speed`` is treated as 1.0.

    Returns:
        int: cost of the request, always >= 1.
    """
    if request.endpoint != 'tts':
        return 1

    args = request.args
    text = args.get('text', '')
    quality = args.get('quality', 'medium')
    audio_format = args.get('format', 'wav')
    try:
        speed = float(args.get('speed', '1.0'))
    except ValueError:
        # Query values are always strings, so ValueError is the only failure.
        speed = 1.0

    length = len(text.encode())
    if quality == 'low':
        cost = round(length / 100)
    elif quality == 'medium':
        cost = round(length / 40)
    else:
        cost = round(length / 30)

    if speed < 0.75:
        cost = cost * 3 // 2
    if audio_format in ('mp3', 'ogg'):
        cost = cost * 4 // 3

    # Debug-level log instead of an unconditional print() on every request.
    app.logger.debug('tts rate-limit surcharge: %d', cost)
    return 1 + cost

# Rate limiter for the whole app.  The limits below are the defaults for
# routes that do not declare their own; each request is charged the number of
# hits returned by cost_function() rather than a flat 1.
limiter = Limiter(
    get_remote_address,  # key function: requests are bucketed by client address
    app=app,
    default_limits=["100/minute", "2000/hour"],
    default_limits_cost=cost_function,
    # NOTE(review): "memory://" selects in-memory storage; counters are
    # presumably per-process and lost on restart -- confirm this is acceptable
    # for the deployment (e.g. multiple workers).
    storage_uri="memory://",
)

@app.route('/', methods=['GET', 'POST'])
def home():
    """Render the demo page: synthesis form, auto-playing preview, curl example.

    On ``GET`` the form is pre-filled with defaults.  On ``POST`` the submitted
    values are echoed back into the form controls and into the ``/api/tts``
    URLs used by the ``<audio>`` element, the download link and the curl
    example.

    Submitted values are untrusted.  Numeric fields that are missing or are
    not valid floats fall back to their defaults, an unknown ``quality`` falls
    back to ``medium``, and every value interpolated into the page is
    URL-quoted and/or HTML-escaped so it cannot break out of the markup.

    Returns:
        str: the complete HTML document.
    """
    if request.method == 'POST':
        form = request.form
        text = form.get('text', '')
        # MultiDict.get(..., type=float) returns the default on ValueError,
        # so a malformed number no longer turns into a 500.
        speed = form.get('speed', 1.0, type=float)
        volume = form.get('volume', 0.3, type=float)
        pitch = form.get('pitch', 0.0, type=float)
        silence = form.get('silence', 0.25, type=float)
        quality = form.get('quality', 'medium')
        if quality not in ('low', 'medium', 'high'):
            quality = 'medium'
    else:
        text = '注意看，这个男人叫里该隐。'
        speed = 1.0
        volume = 0.3
        pitch = 0.0
        silence = 0.25
        quality = 'medium'

    # Shared query-string tail.  Each value is URL-quoted (this also keeps a
    # '+' in exponent notation such as 1e+20 from being decoded as a space).
    query = (
        f'speed={url_quote(speed)}&volume={url_quote(volume)}'
        f'&pitch={url_quote(pitch)}&silence={url_quote(silence)}'
        f'&quality={url_quote(quality)}&format=wav'
    )
    # HTML-escape the finished URLs: '&' becomes '&amp;' inside attributes and
    # nothing user-controlled can terminate the attribute or open a tag.
    tts_path = html_escape(f'/api/tts?text={url_quote(text)}&{query}')
    sample_text = url_quote(text) if text and len(text) < 50 else 'Your+text+here'
    # request.host_url is derived from the Host header, so it is escaped too.
    curl_url = html_escape(f'{request.host_url}api/tts?text={sample_text}&{query}')

    return f'''<html lang="zh-CN"><head>
<title>小彭老师语音合成服务</title>
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
</head><body>
<h1 style="color:grey">欢迎使用小彭老师语音合成服务！</h1>
    <form action="/" method="POST">
        <textarea name="text" rows="10" cols="36" placeholder="输入要朗读的文本" required>{html_escape(text)}</textarea>
        <br/>
        <select name="speed">
            <option value="0.5"{' selected' if speed == 0.5 else ''}>0.5x</option>
            <option value="0.75"{' selected' if speed == 0.75 else ''}>0.75x</option>
            <option value="1.0"{' selected' if speed == 1 else ''}>1.0x</option>
            <option value="1.25"{' selected' if speed == 1.25 else ''}>1.25x</option>
            <option value="1.5"{' selected' if speed == 1.5 else ''}>1.5x</option>
            <option value="2.0"{' selected' if speed == 2.0 else ''}>2.0x</option>
        </select>
        <select name="volume">
            <option value="0.1"{' selected' if volume == 0.1 else ''}>10%</option>
            <option value="0.2"{' selected' if volume == 0.2 else ''}>20%</option>
            <option value="0.3"{' selected' if volume == 0.3 else ''}>30%</option>
            <option value="0.5"{' selected' if volume == 0.5 else ''}>50%</option>
            <option value="0.75"{' selected' if volume == 0.75 else ''}>75%</option>
            <option value="1.0"{' selected' if volume == 1 else ''}>100%</option>
        </select>
        <select name="pitch">
            <option value="-3"{' selected' if pitch == -3 else ''}>-3</option>
            <option value="-2"{' selected' if pitch == -2 else ''}>-2</option>
            <option value="-1"{' selected' if pitch == -1 else ''}>-1</option>
            <option value="-0.5"{' selected' if pitch == -0.5 else ''}>-0.5</option>
            <option value="0"{' selected' if pitch == 0 else ''}>0</option>
            <option value="0.5"{' selected' if pitch == 0.5 else ''}>+0.5</option>
            <option value="1"{' selected' if pitch == 1 else ''}>+1</option>
            <option value="2"{' selected' if pitch == 2 else ''}>+2</option>
            <option value="3"{' selected' if pitch == 3 else ''}>+3</option>
        </select>
        <select name="silence">
            <option value="0.0"{' selected' if silence == 0 else ''}>0.0s</option>
            <option value="0.25"{' selected' if silence == 0.25 else ''}>0.25s</option>
            <option value="0.5"{' selected' if silence == 0.5 else ''}>0.5s</option>
            <option value="1.0"{' selected' if silence == 1 else ''}>1.0s</option>
            <option value="1.5"{' selected' if silence == 1.5 else ''}>1.5s</option>
        </select>
        <select name="quality">
            <option value="low"{' selected' if quality == 'low' else ''}>low</option>
            <option value="medium"{' selected' if quality == 'medium' else ''}>medium</option>
            <option value="high"{' selected' if quality == 'high' else ''}>high</option>
        </select>
        <input type="submit" value="朗读"/>
        <audio autoplay src="{tts_path}"></audio>
        <a href="{tts_path}">下载</a>
    </form>
    <br/>
<pre><code style="color:lightgrey">
curl -sL "{curl_url}" -o output.wav
ffplay -nodisp -autoexit output.wav
</code></pre>
<p style="font-size:14px;color:lightgrey">请求速率限制：每分钟 100 请求，每小时 2000 请求</p>
<p style="font-size:14px;color:lightgrey">low：每 100 字节文本算作 1 请求</p>
<p style="font-size:14px;color:lightgrey">medium：每 40 字节文本算作 1 请求</p>
<p style="font-size:14px;color:lightgrey">high：每 30 字节文本算作 1 请求</p>
</body></html>'''

# def rate_limited(stream, limit):
#     try:
#         for chunk in stream:
#             with limiter.limit(limit):
#                 yield chunk
#     except RateLimitExceeded:
#         yield b'RateLimitExceeded!'

def _float_arg(name, default):
    """Return query parameter *name* as a finite float, or *default*.

    Missing, unparsable and non-finite (``nan`` / ``inf``) values all fall
    back to *default*, so the synthesizer never receives a non-number.
    """
    try:
        value = float(request.args.get(name, default))
    except ValueError:
        return default
    return value if math.isfinite(value) else default


@app.route('/api/tts', methods=['GET'])
def tts():
    """Synthesize speech for the ``text`` query parameter and return audio.

    Query parameters:
        text: text to read; required, at most 2000 bytes when UTF-8 encoded.
        speed, volume, pitch, silence: floats; default 1.0, 1.0, 0.0, 0.0.
            Invalid or non-finite values silently fall back to the default.
        rate: integer sample rate forwarded to the synthesizer; default 0.
        quality: ``low``, ``medium`` (default) or ``high``.
        format: ``pcm``, ``wav`` (default), ``mp3`` or ``ogg``.

    Returns:
        A streamed ``audio/pcm;rate=<rate>`` response for ``format=pcm``;
        otherwise the whole clip encoded in the requested container with
        mimetype ``audio/<format>``.  Validation failures return a plain-text
        message with HTTP status 400.
    """
    text = request.args.get('text', '')
    if not text:
        return 'Text is required.', 400
    if len(text.encode()) > 2000:
        return 'Text must be less than 2000 bytes.', 400

    # An empty value (``format=`` / ``quality=``) is treated like a missing one.
    audio_format = request.args.get('format') or 'wav'
    if audio_format not in ('pcm', 'wav', 'mp3', 'ogg'):
        return 'Invalid format.', 400
    quality = request.args.get('quality') or 'medium'
    if quality not in ('low', 'medium', 'high'):
        return 'Invalid quality.', 400

    # NOTE(review): these values are not range-checked here; confirm that
    # text_to_speech tolerates zero, negative or very large inputs.
    speed = _float_arg('speed', 1.0)
    volume = _float_arg('volume', 1.0)
    pitch = _float_arg('pitch', 0.0)
    silence = _float_arg('silence', 0.0)
    try:
        rate = int(request.args.get('rate', '0'))
    except ValueError:
        # NOTE(review): a missing rate defaults to 0 but an unparsable one
        # becomes 22050 -- kept as-is; confirm which fallback is intended.
        rate = 22050

    stream, rate = text_to_speech(text, speed, volume, pitch, silence, rate, quality)
    if audio_format == 'pcm':
        # Raw PCM is passed straight through as a streamed response.
        return app.response_class(stream, mimetype=f'audio/pcm;rate={rate}')

    # Container formats need the complete clip; join once instead of growing
    # a bytes object chunk by chunk.
    pcm = b''.join(stream)
    encode = {'wav': pcm2wav, 'mp3': pcm2mp3, 'ogg': pcm2ogg}[audio_format]
    # NOTE(review): the registered MIME type for MP3 is audio/mpeg; audio/mp3
    # is kept here so existing clients see the same Content-Type.
    return app.response_class(encode(pcm, rate), mimetype=f'audio/{audio_format}')

# @app.route('/status')
# def status():
#     import psutil
#     # query cpu / memory usage status
#     return f'''<html lang="cn"><head>
# <meta charset="utf-8">
# <meta http-equiv="refresh" content="3">
# <title>小彭老师语音合成服务运行状况</title>
# </head><body>
# <h3>运行状况</h3>
# <table>
# <tr><td>处理器占用率：</td><td>{psutil.cpu_percent()}%</td></tr>
# <tr><td>内存占用率：</td><td>{psutil.virtual_memory().percent}%</td></tr>
# <tr><td>交换区占用率：</td><td>{psutil.swap_memory().percent}%</td></tr>
# <tr><td>磁盘占用率：</td><td>{psutil.disk_usage('/').percent}%</td></tr>
# <tr><td>网络连接情况：</td><td>{len(psutil.net_connections())} 个连接</td></tr>
# <tr><td>网络 IO：</td><td>{str(psutil.net_io_counters())}</td></tr>
# </table>
# </body></html>'''

if __name__ == '__main__':
    # Development entry point only.  debug=True turns on Flask's debug mode
    # (auto-reloader and the interactive in-browser debugger), which must not
    # be exposed publicly; serve the app with a production WSGI server instead.
    app.run(debug=True)
